{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Visualization of the `brax` Experiment Results\n",
    "\n",
    "This notebook visualizes the agent trained by the notebook [Brax_Experiments_with_PGPE.ipynb](Brax_Experiments_with_PGPE.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the pickle file saved by PicklingLogger goes here:\n",
    "FNAME = \"...\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the environment\n",
    "ENV_NAME = \"brax::humanoid\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(FNAME, \"rb\") as f:\n",
    "    loaded = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The unpickled object is a dictionary with these keys:\n",
    "list(loaded.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loaded center solution\n",
    "center = loaded[\"center\"]\n",
    "center"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loaded policy network\n",
    "policy = loaded[\"policy\"]\n",
    "policy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "Below, we copy the values of the center solution into the parameters of the policy network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.utils.vector_to_parameters(center, policy.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Visualizing the trained policy\n",
    "\n",
    "Now that we have our final policy, we manually run and visualize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert ENV_NAME.startswith(\"brax::\"), \"This notebook can only work with brax environments\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "GOT_OLD_BRAX_ENV = ENV_NAME.startswith(\"brax::old::\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "if GOT_OLD_BRAX_ENV:\n",
    "    import brax.v1 as brax\n",
    "    import brax.v1.envs as brax_envs\n",
    "    import brax.v1.jumpy as jp\n",
    "    from brax.v1.jumpy import random_prngkey\n",
    "    from brax.v1.io import html, image\n",
    "else:\n",
    "    import brax\n",
    "    import brax.envs as brax_envs\n",
    "    from jax.random import PRNGKey as random_prngkey\n",
    "    from brax.io import html, image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML, Image\n",
    "from typing import Iterable, Optional\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "Below, we define a utility function named `use_policy(...)`.\n",
    "\n",
    "The expected arguments of `use_policy(...)` are as follows:\n",
    "\n",
    "- `torch_module`: The policy object, expected as a `nn.Module` instance.\n",
    "- `x`: The observation, as an iterable of real numbers.\n",
    "- `h`: The hidden state of the module, if such a state exists and if the module is recurrent. Otherwise, it can be left as None.\n",
    "\n",
    "The return values of this function are as follows:\n",
    "\n",
    "- The action recommended by the policy, as a numpy array\n",
    "- The hidden state of the module, if the module is a recurrent one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:\n",
    "    # Convert the observation to a float32 tensor\n",
    "    x = torch.as_tensor(np.asarray(x), dtype=torch.float32)\n",
    "    if h is None:\n",
    "        result = torch_module(x)\n",
    "    else:\n",
    "        result = torch_module(x, h)\n",
    "\n",
    "    if isinstance(result, tuple):\n",
    "        # A recurrent module returns both the action and its new hidden state\n",
    "        x, h = result\n",
    "        x = x.numpy()\n",
    "    else:\n",
    "        x = result.numpy()\n",
    "        h = None\n",
    "\n",
    "    return x, h"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "We now initialize a new instance of our brax environment, and wrap its `reset` and `step` methods with `jax.jit`. The actual compilation happens when each wrapped function is first called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = brax_envs.create(env_name=\"humanoid\")\n",
    "reset = jax.jit(env.reset)\n",
    "step = jax.jit(env.step)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "Below we run our policy and collect the states of the episodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = random.randint(0, (2 ** 32) - 1)\n",
    "state = reset(rng=random_prngkey(seed=seed))\n",
    "\n",
    "h = None\n",
    "states = []\n",
    "cumulative_reward = 0.0\n",
    "\n",
    "while True:\n",
    "    action, h = use_policy(policy, state.obs, h)\n",
    "    action = jnp.asarray(action)\n",
    "\n",
    "    state = step(state, action)\n",
    "    cumulative_reward += float(state.reward)\n",
    "    states.append(state)\n",
    "    if np.abs(np.array(state.done)) > 1e-4:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "Length of the episode and the total reward:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(states), cumulative_reward"
   ]
  },
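  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pipeline_state(state):\n",
    "    # States of the old brax API (brax.v1) store the physics state as `qp`,\n",
    "    # whereas the newer API stores it as `pipeline_state`.\n",
    "    if hasattr(state, \"qp\"):\n",
    "        return state.qp\n",
    "    elif hasattr(state, \"pipeline_state\"):\n",
    "        return state.pipeline_state\n",
    "    else:\n",
    "        raise AttributeError(\"The given state has neither `qp` nor `pipeline_state`\")"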
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26",
   "metadata": {},
   "source": [
    "Visualization of the policy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline_states = [pipeline_state(state) for state in states]\n",
    "\n",
    "if hasattr(env.sys, \"tree_replace\"):\n",
    "    env_sys = env.sys.tree_replace({'opt.timestep': env.dt})\n",
    "else:\n",
    "    env_sys = env.sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "HTML(html.render(env_sys, pipeline_states))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
